{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST with torchvision and skorch\n",
    "\n",
    "This notebook shows how to define and train a simple neural network with PyTorch and use it via skorch, with torchvision providing the data.\n",
    "\n",
    "<table align=\"left\"><td>\n",
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb\">\n",
    "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>  \n",
    "</td><td>\n",
    "<a target=\"_blank\" href=\"https://github.com/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note**: If you are running this in [a colab notebook](https://colab.research.google.com/github/skorch-dev/skorch/blob/master/notebooks/MNIST.ipynb), we recommend you enable a free GPU by going to:\n",
    "\n",
    "> **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n",
    "\n",
    "If you are running in colab, you should install the dependencies by running the following cell (the dataset itself is downloaded further below by torchvision):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "! [ ! -z \"$COLAB_GPU\" ] && pip install torch scikit-learn==0.21.* skorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import islice\n",
    "\n",
    "from sklearn.model_selection import train_test_split\n",
    "import torch\n",
    "import torchvision\n",
    "from torchvision.datasets import MNIST\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_TENSORBOARD = True  # whether to use TensorBoard\n",
    "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "MNIST_FLAT_DIM = 28 * 28"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Data\n",
    "\n",
    "Use torchvision's dataset collection to provide the MNIST data in the form of a torch `Dataset`. Originally, the `MNIST` dataset provides 28x28 `PIL` images. To use them with PyTorch, we convert those to tensors by adding the `ToTensor` transform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_train = MNIST('datasets', train=True, download=True, transform=torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_test = MNIST('datasets', train=False, download=True, transform=torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Taking a look at the data\n",
    "\n",
    "Each entry in the `mnist_train` and `mnist_test` Dataset instances consists of a 28 x 28 image and the corresponding label (a number between 0 and 9). The `ToTensor` transform has already scaled the image data to the range [0, 1]. Let's take a look at the first 5 images of the training set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_example, y_example = zip(*islice(iter(mnist_train), 5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.), tensor(1.))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_example[0].min(), X_example[0].max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print a selection of training images and their labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_example(X, y, n=5):\n",
    "    \"\"\"Plot the images in X and their labels in rows of `n` elements.\"\"\"\n",
    "    fig = plt.figure()\n",
    "    rows = (len(X) + n - 1) // n  # ceiling division, avoids an empty trailing row\n",
    "    for i, (img, label) in enumerate(zip(X, y)):\n",
    "        ax = fig.add_subplot(rows, n, i + 1)\n",
    "        ax.imshow(img.reshape(28, 28))\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_title(label)\n",
    "    plt.tight_layout()\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAABnCAYAAABVe9YVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAATc0lEQVR4nO3deXhV1bnH8e/KABmYDBoEyyQQIjiggooDaEW0LdehiohaubQ+FrhQpdLS+thBxVttrVQpzmVQW/U600FtUeSxiigWrVUmUSLIIFOYw3Cy7x/rnLUOnhMM8SR7n+T3eZ48LN/sffJme07WXmuvwQRBgIiISNTkhJ2AiIhIOqqgREQkklRBiYhIJKmCEhGRSFIFJSIikaQKSkREIkkVlIiIRFLWVVDGmFeNMVXGmO3xryVh55TNjDElxphnjTE7jDEVxpjLw86pMTDG9Ii/Tx8NO5dsZowZa4xZYIzZbYyZEXY+jYEx5ihjzCvGmC3GmI+MMReFnVNNsq6CihsbBEGL+FfPsJPJclOBPUA74ArgXmNM73BTahSmAm+HnUQjsBqYBEwLO5HGwBiTBzwP/AUoAa4BHjXGlIWaWA2ytYKSDDDGFAMXAz8LgmB7EAT/BGYB3wk3s+xmjLkMqAReDjuXbBcEwTNBEDwHbAw7l0aiHOgATA6CIBYEwSvA60T0M5+tFdSvjDEbjDGvG2PODDuZLFYGxIIgWJoUew9QC6qOjDGtgJuB68PORSQNU0Ps6IZOpDaysYKaCBwJHAE8APzZGNMt3JSyVgtgyxdiW4CWIeTSWNwC/CEIgpVhJyKSxmLgc+BHxph8Y8xgYCBQFG5a6WVdBRUEwfwgCLYFQbA7CIKZ2ObpN8POK0ttB1p9IdYK2BZCLlnPGNMHGARMDjsXkXSCINgLXAh8C1iLben/H7AqzLxqkhd2AhkQkL7ZKl9uKZBnjOkRBMGyeOw44IMQc8pmZwJdgE+NMWBbqLnGmF5BEJwQYl4iThAE/8a2mgAwxrwBzAwvo5plVQvKGNPGGHOuMabAGJNnjLkCGAC8FHZu2SgIgh3AM8DNxphiY8xpwAXAI+FmlrUeALoBfeJf9wF/Bc4NM6lsFv+cFwC52Mq+ID4STerIGHNs/DoWGWMmAO2BGSGnlVZWVVBAPnbI6XpgAzAOuDAIAs2FqrsxQCG2X/oxYHQQBGpB1UEQBDuDIFib+MJ2oVYFQbA+7Nyy2I3ALuAnwJXx8o2hZpT9vgOswX7mzwbOCYJgd7gppWe0YaGIiERRtrWgRESkiVAFJSIikaQKSkREIkkVlIiIRNJBDddsZpoHBRTXVy5ZaxubNwRBcNjBnqfrmZ6uZ2bpemZWXa8n6JrWpKZrelAVVAHFnGzOzlxWjcTs4KmKupyn65mermdm6XpmVl2vJ+ia1qSma6ouPhERiSRVUCIiEkmqoEREJJJUQYmISCSpghIRkUhSBSUiIpGkCkpERCJJFZSIiESSNv6S/ez7+omuvGaM3SLmvf5+s83j5o0AoMPUZi6WO+dfDZSdiDQlakGJiEgkZUULyuTZNHMPO/SAxy2Z0AWAWFG1i3Xu9jkARWOMi6290979/6vvEy62IbYDgJOfvN7Fuv/wza+QdXapHng8AHdP+72Ldc+317066biF/acDsKRvzMV+1OWU+k+wCdlxyckA3P7re13slkuvAiBY8J9QcsoWy3/TH4BFl/v3cb7JBWDAmGtcrPC5txo2MakTtaBERCSSVEGJiEgkhdrFl3tUDwCC5vkutnpgGwB2nbLDxUpa2/Jrxz3BwXphZ0sAbv/9eS42/5g/AfDJ3l0udtu6cwDo8Fpw0D8jW+0d3NeVf3zPIwCU5fvBD9Xxzr2P9+51sS3VzQE4vrl/nd3f6AdA4Zz3/blVVZlPOEN2XXCS/bdtrouVTJsXVjopPu9r7xtvWfFfIWeSHdaOP9WVXx32awD2
Bs1SD2w6H+1GQy0oERGJpAZvQcXOPMGV75wxFdj/rj0T9gb+Af7Pp/w3AHk7/O1T/yfHAtDys30u1nyDbU0VLZif0VyiIrdVK1feMaAcgPGT/+RiZxVuj5dS71lmbPZ3qC/fYx9Cv/7Lu13sHw/dB0CvR8e62JETo9Mi+aLVA+zvWNSt0genhZRMQo5vzQWd7Hvx7NLFLvayOTXlFLG2d/TDeEpyMvu3pLHYc67tLam4wl+r0SfMBeC6Q5amHH/MQ+NcuWiN/dtZeepuF+v8R/sZavbSgswnm0QtKBERiSRVUCIiEkkN3sXXfMlqV36nqiMAZfnrDvp1rl9j5958vN3PjZrR7SkAtlT77rx2d79Rq9dr7M9PVz18hCu/3W/qQZ17c+nbrvxiC9vVNHLFYBeb2WU2AK16bfwqKTaYm4Y8CcDtiwZ/yZENJ7dbZ1dePND2N/Z560oX6/D2+ynnNHXbh9r5Yk9fdFdS1M53vK+y3EVmX2q7t4orPnCx5Ll9jdX6Uf1decqP7We+b3P/+CMn3j4ZsWKQix3f+lMA3rs6+ZrufzzAqSXDASh5KYMJp6EWlIiIRFKDt6D2rVnrylNuHwrAref5IeW5/24BwHtjpqScO2nDsa780aAiAGKVa1zs8v5jAFjxA39OV97LQNbZK7G23mN9/Mz6HFIfJI+sOBuABbOPcrH3v2fPmbOrwMVKF9gH+B9t9neo+f87x76uX6wj0vLNvi8/qIHlPbQzJbZreas0RzZtVUNOcuVf/Mq2NMvyU994Mx/000oO/7B2vSjZzCQNNKsadBwAT//0Ny7WIc/OC/lexTkuVnFHTwCK//qui80p6gTA3GfLXOzpHrNSft7Wd9sCUPKVMz8wtaBERCSSVEGJiEgkhbqSRMl0O1fmsD+3dbHYxk0A9D76uy72wQDblJ/1wEAXK61MbbabebY7r2t0p+A0iMTCr+AXf00s/Ap+hYjzF1/kYrmX2G7WNt/yw0V6PWLnNZVNXeliOSsXAnDIa/7n7b3VPnh9+lg/mei7Z9l+1qhsxVF9eh9XPqPgnyFmkl6X4tQBJh1nx9Ic2bStudKvUHJWYaLs55AlHvgfflfj79ZLtmasXxXmrQmJAQ5+uZehH9lVSfZd7FeFKdpg53wmDxBbfY19JDC/R+ogicSqPADd77d/E+q7s1wtKBERiaRIbLcR25B697h3a+qD/N5XfOjK6++N3zVV6y4zwZzYG4ANP/RrDCZW6XjHTwLnle29ANj4eEcXa7vZNjtbP+q3GGkd/7e2d0ntcv0d28br7EP/0jm1PLmeVQwpdOXS3KIQM9lfXhf7UPqSktQH0YWfbHblpv4uz/uanSbxwRnTXSyxYswi3yjg0zvtw/1iGueKMF+0bIodar/k235QWWII/VH/GOVi5RNWAOn/1iYbNfr5Gr836dYRrnzIyobpplILSkREIkkVlIiIRFIkuvjSOWqiX8Bw5DF2js70zi+72MCh/wNAyyeazq636eQU+e6qfb/eCsCb5c+42Cf79gDwwxv8TsGHvGZni5cWf+5ime5COql9BQArMvy6dZXXfVtKrGpxmxAy2d/K3xUDcFpzv7bBH7Z+zRYqt4aRUmTk9u7pyn3/VPNOwsOe8RMfuz3d+P8eLP+t38F6ybftChFbqv3gkaGLLweg5zj/NzS2LfX9n1Ns33sbL/HzSy9oYedO5eC7xMuftH9ru89o+NFnakGJiEgkRbYFFavc4sobR9vVDT6d5R/+/2TSwwD89FI/VDpYaB/rd7w1qaYPGvcqe7sG9nbll8rvSfn+1deOB6Dlc/7OMnrrKISjdEH9r8iWe6ifQrHuYvsAv+TSVS42t+wP8ZJfrePeqRfa/NY1raHSX1Rxvr92T7VdGC/5IeWXL7dDp8tuW+5ijXkwSW67UgBmXuQ/54kpI4lWE0Czcyri30uV06eXKx89bREAk9rdnXSEHeh02ruXuUjPX9rjwri2akGJiEgkqYISEZFI
imwXX7Lq92wT87KbfuRif/zFHQC8e8rD/sD4s8PexX5n1x4P2sVk9328on6TDMmxt/iFHhPL4ScWfgUofO6tes8h39hul71Jvam5Jvpdq7tK/P1Z8QGOqz7Dr8wR5NqFSVcO8nO+9nSwE3FymvlOkL+fYeelJK9jujZmz/nZx75belO17YgpyvHntptvH2hH/wrWj00j7TYRz476TVI0H4BRK/1qMntH2OsZW/9pg+UWJlNgf9/kLTMSCn/g542aznZ+47JRX3OxwYPsii7jSx9wsU55diBEcldgLP5IxDzhtzGKVS77ipnXnVpQIiISSVnRgkoomeYHP4xdYoc+trrNP3B+7Ei7e9YHV/mtJco7Xg1Az5t8XRxb9nG95tkQKr9j7zJvbHeHi1XHt9F45+/+QWgn6v9Be2JGf3XSvdiLi2wOPYjGWny7q/JduTreNpl+w2QXmzW2T8o5CRPbPuTKOfEN8XYFe1xsdcz+/r9ff6aLDZp9HQBtFvo72/Z/txtzmgr/nl2/yN7Ftsv1yyEETXBzwuQh5W9MSnx+C1KOm7eqiyt3XFHz0PPGKKiyy8HM3+3fyyc3t++b52c/7mLVB9iOcfYu3zJaFu/yOKtwu4st2GPfr20ejsaCpmpBiYhIJKmCEhGRSMqqLr5k5nU7OGDnJaUu1m/YOADmT/RLxS8+y3bPXNFlsIttOb0hMqxf++ITvVvn+C6keVX2IeqRD6/2x2X45yZWrlh8x9FJ0XcAuOLjb7hI+bWfANGZl9L9yoWu3PtXdhBNx36f1ercOZ/73UXXv2AfPLf9wHfJNXvx7XjJx8pYkPI6iWvx2cRTXaxfc9uV8vj2I2qVS2O19Aa/IkqiyzidTrf5clMbRBJbZ1d++cXoq13sjvvsnKhjk9bWfnSrHSQxae75LlY2w640kbfOzy8tfcxubXRWx1dcbMQc+9rp3r9hUAtKREQiKWtbUAmJuwqAdnfbctWPfbuhyNhbiwe7/MXFhlxkH2AXPdu4luTfGGsBZH5IffJ6f0tuOwaAxRf4gSgv7LQreKye2t3FWm6O7ppoXX9a9wfA7fnqQ5qLBqxPid0452JXLqP+pwZERWJzzUl9nzvgcef8x65s0GJB0xoYkU6zl3zr5oauJ9V4XLr30bYL/PF/7WS31tgb+HZK4YrUbY7CpBaUiIhEkiooERGJpKzt4qs+3c5bWT7Uz5U4us8KwHfrJZuyya8GUPR8NB4AZtqE14cCUBYftPBVJbpfPk/aoXdRX9u1d/b7w1ys+Dw7r6wl0e3Wi7rOzze1R/7WrTPsygZH56f+/hPWDHDl1sPt7sJRGXSTrfYV+jZJuvmLXWfYLuyoLCitFpSIiERSVrSgTF87pHlp0npTD542E4ABBXvSnpOwO7BDf9/c1NUHq9dkOMMQxNd4y0m6x7jr9McAmEpZujNqpeLm/q789FV3AlCW76/7CW+NAKDDRR/W+WeIJBzfzL5/0w0tnzf9BFcu3dy0tx7JlJaPJ/Vy/Da8PGpLLSgREYkkVVAiIhJJkeviy+vaGYDlIzu42C+H2YUQL26xoVavccO6vq489y67B8chM6Ox+GHGxJ8pJz/gHFi4EYDrZpzoYt2m2+/nr93mYusGHgZAyTC/aOm4Ti8D8I0iP8Bi1o52AFz1/nkuduj9B9qYQg5WrrH3iJvL/AKgh78QVjYNY+VTfhWSfPNujce1f9V/3jU4IjO2XXZK0n9lZjBVfVILSkREIinUFlRel04AbDmxvYsNu/lFAEa1eaZWr3H9Gn9HMO8e23IqmeFnUB9S3chaTgdQYOz/zkXn3Odi/zzDDsNftvtwFxvZekWNr3Ht6jNc+cU37FD+Htdq+Hh9iQXxFnATuFVMTFv4XZ9HXSwxOGJLdZWL9XvBrvRSXqGBOJm25cjseqNlV7YiItJkqIISEZFIarAuvrz2totp0zT/kH1017kADG+5rlavMfYzv0/Gv+613U+HPuUXjyzZ1nS689q9
ahfGnfh9P2/p9sNTf//EPLHTC1akfG/hbn9/MnzuNQCUjfQPTntoZYgGs7PfzrBTqHdVJXY+3ekFO5KiuQC8tLOTi5RdY7cvqXlfWKmrI+b691n+WHvt90Z4ERO1oEREJJLqpQW151w7WGHP+E0udkP3vwEwuHBH2nO+aF3Mr/82YNb1AJTfuNjFSipta6Gp3mXFli4HYNnQLi7Wa5zdsPHDS6cc8Nzyv40BoOc9/m6qbGH0h5w2Rolh5iINIbHRK8CMrXaz1+Et/cadO3vbAWvNVq4iCvTpEBGRSFIFJSIikVQvXXwrLrT13tJjnjzgcVMruwFw19zBLmZidhXU8kmfuFiPdXbnW80mT5W8e2738bZ8/vh+BzynDPsQOsLPRhu13bMPc+VYn6bTSd3q3bUAjFv1dRe7r+PcsNJp8ibffwkAwyfc5WLtf/YRABsrj/UHvvnvBs0rmVpQIiISSfXSgiobbVdyGDL6xC85Mn48b6XE1FqSxurwyX7riG9OtltKHEnNa9I1Fvs+qQBgVdJycEOo3d8IybwjHlkCwLALh7jYE93/AsDAnw93sZLLWwMQq9zSgNlZakGJiEgkqYISEZFIitx2GyIiUv9iG+z2PHsubutiR/32+wAsGnS/i51f/j1bCGGwhFpQIiISSWpBiYg0YYmWFECPEbZ8PslTVTTMXEREZD+qoEREJJJMENR+PQFjzHqgov7SyVqdgyA47MsP25+uZ410PTNL1zOz6nQ9Qdf0ANJe04OqoERERBqKuvhERCSSVEGJiEgkqYISEZFIUgUlIiKRpApKREQiSRWUiIhEkiooERGJJFVQIiISSaqgREQkkv4fn4UEUgYuwRMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_example(torch.stack(X_example), y_example, n=5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preparing a validation split\n",
    "\n",
    "skorch can split the data for us automatically, but since we are using `Dataset`s for their lazy-loading property, skorch cannot do a stratified split without first iterating over the whole dataset to collect the labels (which it doesn't do).\n",
    "\n",
    "If we want skorch to do a validation split for us, we need to retrieve the `y` values from the dataset and pass these values to `net.fit` later on:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = np.array([y for x, y in iter(mnist_train)])"
   ]
  },
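  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea of a stratified split concrete, here is a minimal pure-Python sketch. It is not skorch's implementation: the function name `stratified_split`, the 20% validation fraction and the toy labels are assumptions chosen for illustration. It only shows why the labels are needed up front: the indices have to be grouped by class before they can be divided.\n",
    "\n",
    "```python\n",
    "import random\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def stratified_split(labels, valid_fraction=0.2, seed=0):\n",
    "    # Return (train_idx, valid_idx) so that each class keeps its proportion.\n",
    "    rng = random.Random(seed)\n",
    "    by_class = defaultdict(list)\n",
    "    for idx, label in enumerate(labels):\n",
    "        by_class[label].append(idx)\n",
    "\n",
    "    train_idx, valid_idx = [], []\n",
    "    for indices in by_class.values():\n",
    "        rng.shuffle(indices)\n",
    "        n_valid = round(len(indices) * valid_fraction)\n",
    "        valid_idx.extend(indices[:n_valid])\n",
    "        train_idx.extend(indices[n_valid:])\n",
    "    return sorted(train_idx), sorted(valid_idx)\n",
    "\n",
    "\n",
    "# toy labels, imbalanced on purpose: 80 zeros and 20 ones\n",
    "toy_labels = [0] * 80 + [1] * 20\n",
    "train_idx, valid_idx = stratified_split(toy_labels)\n",
    "print(len(train_idx), len(valid_idx))\n",
    "print(sum(toy_labels[i] for i in valid_idx))\n",
    "```\n",
    "\n",
    "With these toy labels, the validation part receives 16 zeros and 4 ones, so the 80/20 class ratio is the same in both parts."
   ]
  },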
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build Neural Network with PyTorch\n",
    "\n",
    "We build a simple, fully connected neural network with one hidden layer. The input layer has 784 dimensions (28x28), the hidden layer has 98 (= 784 / 8), and the output layer has 10 neurons, representing the digits 0 to 9."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch import nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A simple neural network classifier with linear layers and a final softmax in PyTorch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ClassifierModule(nn.Module):\n",
    "    def __init__(\n",
    "            self,\n",
    "            input_dim=MNIST_FLAT_DIM,\n",
    "            hidden_dim=98,\n",
    "            output_dim=10,\n",
    "            dropout=0.5,\n",
    "    ):\n",
    "        super(ClassifierModule, self).__init__()\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "        self.hidden = nn.Linear(input_dim, hidden_dim)\n",
    "        self.output = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, X, **kwargs):\n",
    "        X = X.reshape(-1, self.hidden.in_features)\n",
    "        X = F.relu(self.hidden(X))\n",
    "        X = self.dropout(X)\n",
    "        X = F.softmax(self.output(X), dim=-1)\n",
    "        return X"
   ]
  },
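  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sanity check of the layer sizes, the number of trainable parameters can be worked out by hand. The helper `linear_params` below is a hypothetical name, not part of PyTorch; it assumes that every linear layer has a weight matrix plus one bias per output unit, which is the `nn.Linear` default. Dropout adds no parameters.\n",
    "\n",
    "```python\n",
    "def linear_params(in_features, out_features):\n",
    "    # weight matrix plus one bias per output unit\n",
    "    return in_features * out_features + out_features\n",
    "\n",
    "\n",
    "hidden_params = linear_params(28 * 28, 98)  # 784 inputs -> 98 hidden units\n",
    "output_params = linear_params(98, 10)       # 98 hidden units -> 10 classes\n",
    "total_params = hidden_params + output_params\n",
    "print(hidden_params, output_params, total_params)  # 76930 990 77920\n",
    "```\n",
    "\n",
    "Almost all parameters sit in the first layer, which is typical for dense networks applied to flattened images."
   ]
  },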
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "skorch lets us use PyTorch through an sklearn-compatible API. We will train the classifier using the classic sklearn `.fit()`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch import NeuralNetClassifier\n",
    "from skorch.dataset import CVSplit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We might also add tensorboard logging. For that, skorch offers the `TensorBoard` callback, which automatically logs useful information to tensorboard.\n",
    "\n",
    "**Note**: Using tensorboard requires installing the following Python packages: `tensorboard, future, pillow`\n",
    "\n",
    "After this, to start tensorboard, run:\n",
    "\n",
    "`$ tensorboard --logdir runs`\n",
    "\n",
    "in the directory you are running this notebook in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = []\n",
    "if USE_TENSORBOARD:\n",
    "    from torch.utils.tensorboard import SummaryWriter\n",
    "    from skorch.callbacks import TensorBoard\n",
    "    writer = SummaryWriter()\n",
    "    callbacks.append(TensorBoard(writer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "net = NeuralNetClassifier(\n",
    "    ClassifierModule,\n",
    "    max_epochs=10,\n",
    "    iterator_train__num_workers=4,\n",
    "    iterator_valid__num_workers=4,\n",
    "    lr=0.1,\n",
    "    device=DEVICE,\n",
    "    callbacks=callbacks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.7908\u001b[0m       \u001b[32m0.9005\u001b[0m        \u001b[35m0.3620\u001b[0m  2.3784\n",
      "      2        \u001b[36m0.4249\u001b[0m       \u001b[32m0.9213\u001b[0m        \u001b[35m0.2846\u001b[0m  2.2981\n",
      "      3        \u001b[36m0.3557\u001b[0m       \u001b[32m0.9303\u001b[0m        \u001b[35m0.2411\u001b[0m  2.2295\n",
      "      4        \u001b[36m0.3192\u001b[0m       \u001b[32m0.9376\u001b[0m        \u001b[35m0.2147\u001b[0m  2.2887\n",
      "      5        \u001b[36m0.2877\u001b[0m       \u001b[32m0.9434\u001b[0m        \u001b[35m0.1970\u001b[0m  2.2926\n",
      "      6        \u001b[36m0.2676\u001b[0m       \u001b[32m0.9471\u001b[0m        \u001b[35m0.1809\u001b[0m  2.3752\n",
      "      7        \u001b[36m0.2534\u001b[0m       \u001b[32m0.9494\u001b[0m        \u001b[35m0.1704\u001b[0m  2.3644\n",
      "      8        \u001b[36m0.2413\u001b[0m       \u001b[32m0.9521\u001b[0m        \u001b[35m0.1602\u001b[0m  2.5879\n",
      "      9        \u001b[36m0.2295\u001b[0m       \u001b[32m0.9557\u001b[0m        \u001b[35m0.1519\u001b[0m  2.3586\n",
      "     10        \u001b[36m0.2189\u001b[0m       \u001b[32m0.9572\u001b[0m        \u001b[35m0.1464\u001b[0m  2.3270\n"
     ]
    }
   ],
   "source": [
    "net.fit(mnist_train, y=y_train);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred = net.predict(mnist_test)\n",
    "y_test = np.array([y for x, y in iter(mnist_test)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.958"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An accuracy of about 96% for a network with only one hidden layer is not too bad.\n",
    "\n",
    "Let's take a look at some predictions that went wrong.\n",
    "\n",
    "We compute a boolean mask that marks the misclassified elements and plot a few of those to get an idea\n",
    "of what went wrong."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "error_mask = y_pred != y_test"
   ]
  },
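  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relation between such a mask and the accuracy can be illustrated with a small pure-Python sketch. The toy labels and the `toy_` names are made up for illustration and are unrelated to the actual predictions in this notebook.\n",
    "\n",
    "```python\n",
    "toy_true = [7, 2, 1, 0, 4]\n",
    "toy_pred = [7, 2, 7, 0, 9]\n",
    "\n",
    "# True wherever the prediction differs from the label\n",
    "toy_mask = [p != t for p, t in zip(toy_pred, toy_true)]\n",
    "\n",
    "# positions of the misclassified samples\n",
    "toy_wrong_idx = [i for i, wrong in enumerate(toy_mask) if wrong]\n",
    "\n",
    "# accuracy is the fraction of entries that are not flagged\n",
    "toy_accuracy = 1 - sum(toy_mask) / len(toy_mask)\n",
    "print(toy_wrong_idx, toy_accuracy)\n",
    "```\n",
    "\n",
    "With numpy arrays, as used in this notebook, the comparison `y_pred != y_test` produces the same kind of mask in a single vectorized step."
   ]
  },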
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have the mask, we need a way to access the images from the `mnist_test` dataset. Luckily, skorch provides a helper class that lets us slice arbitrary `Dataset` objects, `SliceDataset`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "from skorch.helper import SliceDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_test_sliceable = SliceDataset(mnist_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_pred = torch.stack(list(mnist_test_sliceable[error_mask]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAABnCAYAAABVe9YVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjAsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+17YcXAAAUFklEQVR4nO3deXgURfoH8G/lBkIIIUCQK8QQAuhyKkZFQER+SsDjh4AH6IoXrOCKt6Au4ImKIqAugqiIKIKgIuK1iygsEEBFJQQCkUuQhOVICIQcvX/UTNXsdgIhzKSrJ9/P8/jQebtnpqxMpuatrkNYlgUiIiLThDhdACIiovKwgSIiIiOxgSIiIiOxgSIiIiOxgSIiIiOxgSIiIiOxgSIiIiO5soESQgwRQmQKIY4KIbYJIbo7XSY3Y336nxCitRDiuBDiXafLEgxYn/4jhFjuqcsCz39ZTpepImFOF+B0CSH6AHgOwGAAawE0cbZE7sb6DJjpADKcLkQQYX36192WZc10uhCn4roGCsB4ABMsy1rt+XmPk4UJAqxPPxNCDAFwCMAqAMkOF8f1WJ81l6u6+IQQoQC6AmgohMgWQuwWQkwTQtRyumxuxPr0PyFEDIAJAO5zuizBgPUZMM8IIfKEECuFED2dLkxFXNVAAWgMIBzAQADdAXQE0AnAOCcL5WKsT/+bCGCWZVm7nC5IkGB9+t9DAJIANAUwA8CnQoiznS1S+dzWQB3z/DvVsqy9lmXlAZgM4EoHy+RmrE8/EkJ0BHAZgJecLkswYH0GhmVZayzLyrcsq8iyrLcBrIShf/OuugdlWdZBIcRuAFyC3Q9Yn37XE0AigJ1CCACIBhAqhGhnWVZnB8vlVj3B+qwOFgDhdCHKI9y23YYQYgKAKwD0A1AM4BMAyy3LeszRgrkU69N/hBC1AcT4hO6H/IAdYVlWriOFcjHWp/8JIWIBdAPwLYASyNG7MwB0tizLuOHmrsqgPCYCiAewBcBxAPMBPOVoidyN9eknlmUVAij0/iyEKABwnB+mVcP6DIhwAE8CSAVQCmAzgKtNbJwAF2ZQRERUM7htkAQREdUQbKCIiMhIbKCIiMhIbKCIiMhIpzWKL0JEWlGoE6iyuFY+DuZZltXwdB/H+iwf69O/WJ/+VdX6BFinFamoTk+rgYpCHXQTvf1XqiDxtbVgR1Uex/osH+vTv1if/lXV+gRYpxWpqE7ZxUdEREZiA0VEREZiA0VEREZiA0VEREZiA0VEREZiA0VEREZiA0VEREZiA0VEREZy435QpxSW0Fgdn2h9VoXXhW/Zo46zHkkCAMRu0htLxmUeBwCEfPeDv4tYbUIbNwIAlOYe0MGyUodKQ0RuJcIj1HFR7w4AgP2dwlWspGMBAKB30hYV+2Z7CgCg+Wu6qQldvqHSr8kMioiIjOT6DOrwTReo4wNXyozn4U7LVGxYzNIKHzvrcAt1fG3dRQCA+tdF2a5Lb9rljMvplOiFMlvKPdZMxQrfkVll7Jx/Bfz1w1o2V8elv/8BALCKTwT8dcm91Dd1q0zFrJISh0pTM4nISHV8vPefAABl9+Sp2Ipz36jcEzVdDQAYn9pOhVZ1iKjoahtmUEREZCQ2UEREZCRXdPGFdGgLANg8Si9T/93lLwMAGoZm6OtOs70dXm+nz0/2rr1gkPHL2QCA7P6vq1j75LsBALHV8PqZD+pBKlZYAgAg5c6Mii6nGurQ0DR1/ObEyQCA9cd19/ALMwYBAJq8vEY/iIN9/E50PQcAEPViroota1257rz1RbLr/qW9l9vO5Y5p4fPTxkqXhxkUEREZyRUZ1NFWdQEAW654zSdaq8rP9/ohOaR87o7zKnV9PWRX+bWcVmu3M7/igkFy8Mr6AZNVLCZEZqlXonO1l2fb3E7quM46+d5JeGlVtZeDytdgw7/VcfqXowEAU3rNVbH1900FAIwcfImK7RqZCACw
1v9aDSUMPt7BKFlTOqrYP/rJv9dW4dG267cUH1XHfT+7FwCQ/J4e8BSx5yAAoCSnvK2dDlWpjMygiIjISGygiIjISI528YU1awoAyHxIz9FpvEqu5BAzb7WKhRRZAIAtPvNndpXIW/zNw3TqeMsvNwMADmY20M+XIR8bu2qXilkFcsZzvUPu7bqrrBHXf+bI6+69RNa7t1sPAJ7I7eBIWQAgu9dsdbw6Td5cfyxjuIqFfP9jtZepIlaarKetw/RclHbjZbdJyb4/HClToJX+mqWOU+6Q/05Hioo9O1h2Gb/x3EsqVneR/D3e1XuYfp6t2wNZzKCSNU2+z3L6z/CJ2rv2rt7aFwBQ+EiCiqWsWmu7LhAz1ZhBERGRkao9gwqNraeOz/8sBwCwOP4TFbto3d22x0R+LoclP9DvFhXzfuMKbdtaxeKytsl/y/RaUF41aR66daHOVPpGv+o5qvqgkqq47qI1ttiXL10MAKiPwK9g8b98b/B2iZRZ3bYhOkNJ3RgDACg9cqRayxXWRH4rzbk1ScXeum0KAKBThP7+2LGRzBKa/X9wZlCnUvcD2aNyc90xKrZy/CsAgINdG6lYDDOocqkBEVP1gIjsdO/UE/0+KyiTq/Gcu0x/DqfeKz9rRf5PAS6lHTMoIiIyEhsoIiIyUrV18YVEyW6VogW6i+/R+H8AANp8NFLFUhfJOQ3lzRH3vZGqYplb/VjK4LD3Yr3ixtlh9q69sMLAvG5I7drquG6onNeyv1S/WPwy2QXrxPz/YWPvU8ffPTcdAJB1zasq1j1pMAAg+vlkFYvYJed1lGbnVPl1fbug918cDwAo6FOgYo91kIsZD4r2Hcxi/974UPsvAABz0cx2riZpMFN3D88cI7tF8zroLXJi5lV7kVzB27WXM8B3QIR8ny0+qgdGPP6a/CxOmaznCJbBOcygiIjISAHNoELr11fHmyfKIaNZbfW31vVF8t/UCfrGZnXfpA4m3voe9efFtnM9fh6ojps+F5gVFPYP1YMzHmkgs5TUb/XN1qQ/nBvKHfuB3iTtquH9AQCLUj5Wse86fAAAyHv7mIrlW/Kb+aEyvT3A1/lyrbK4MJ0FTfqqfzkvKKdEfNXzFRVqUU42W1lPfHsNACAF7l/HMLS1HhCyfajcXLTphXrz0H2H5coxdRfVVbG4r+RnRElSExVLjJC/s7MX6t+FFYDyuo0aEDFN/z2WNyDicJl8r08ZdbuKNVlm1uoqzKCIiMhIbKCIiMhIAe3i+/2mtuo46xq52OMnR3W336z0PgCA0txtgSxGjZH9mtyeYHjMN7ZztZ6qZ4v5W/MbzJ2D4ruLb2mv3wEAbV/R3Y/v958GAOgUobvh4st5nk4NNtliwwdOP8kr6+fr+fN1AIAXUz5UsS6RtgegyCqW52bfq2Jtp8i5fW7bYCK0QZw63n1LKgBg0ehJKub9hpxR1FTF6oTIvv9+acdVbPkxeeVnh3W31ZNb+wEAYjJ+9m+hXe730V0BADn9X/WJ2nORASPuAQBELbOvCmEKZlBERGSkgGZQ+d2O2WJTcnqr41pbmDmdqYM3643elqS94DnS39oXejLWsA16dQ1/DxsNbdgQANApdtcprjRL69F6tYsHP5fDa3dcp2+zL7lUZv0HynR9bjiWCAD4S6x+73ozHl+dVtwFAIj5Vj+26Q1yuHrLMN+/C3m+0NIZ3lV3yO0mEpfqIdVuy5y88tLbqOOoS+UmeJcvul/F2vz9AID/ni7inZIyYZDeImXNs3KrnZ61flCxQYUyOytop9fsK91kX0WmJijqp7cOWjD6ec+Rnm7iXSEi7WW9EsdZS6p/RZfTxQyKiIiMxAaKiIiMFNAuvnkX2WctL2j3roqkTZaz+1t9ors3QpdvAJ1aaLzcUuSyv65UsfJWjXjj9msBACFHf7Cd85fidnJ1g3HxX9jONX/LFZs2qwWJUz7XsTFIs11XdKXsSnmn5RUqFvVv++ybpA/l8/02/nwVW5js
XS3C/ns6b47uemm11Pyul1PxdjnNm/C8it04VnbtJb+rt9Ipr+uy7LjsjjoRLWzn3s/Xg6zeSFwCAPhliR5pMm6E3Ksj4ot1VSy5ywhZR0dHHFahlPA6tst2lMj3qM+UPoTW8y6QrOeRocyszmRmUEREZKSAfr09PzJcHRdbsmWu77OB3ebBcnhu8SDdap/zjby5XC9DX1fQzLP5nc8o5viNevsEr7w/yW8OjZfvV7HSIB2IYTWTM/AnNvrKds531YjolXIIrlMz7CP/0L8nJ9f08pfIpTIzaniqCz3fbC/uc/Ih0D+ckLWS/KK+uW/Wd9iq6TJhPQDg3UM6g4z7vHJD5Y/cIDcnfOG+v6vY+Nx2AIC1V+m1Eif3kEPT7x/7nop9OFNuVXLTtXepmBXEw9C9WVBG5/knva69Z/rEr3f7DD33zLJI+lJv3Nl2ghy0UrL9N/8V8gwwgyIiIiOxgSIiIiMFtIuv1ad6EcItarFCu3ARqo6zLntDHlxW9ddd+7C+ufrXTUMAAHHp7p8fkT/kAnV86cMrbefn5MvdWeuN1J1pJSWevYSFrpPQ2NgKX8MqKlLHItK+zEHpoUOeC0/eaejtZqyzcfNJrwta3c4FALzefLbtVGaxnjf1wCg55ykqz9zZ/FWxOjcRAHDw+wQVa36g4oVIfbclGT9hFgAgtyRGxdZeLc+X/LZDxer/thMAMPunfvqJ5suBKNe+o1dTWXztRQCCc2uezOe9c8C+rfJzbL98ljqeel5LAMCyfnrFjhJPPTuBGRQRERkpoBlUm7/ooc19P5TDP4dN+1TFanvW3EqvnativtlUVZ0fqb/df99pLgCg/fOjVezsB9w5jPePdJ3djG/4k+18XKgcLpr5tzjfKAAgJFTXSVbPWfhfoUJ+V3kit72KPR5vv7nc9Wl5Z7XJ+zozyukbZbsu74gcsFLnFJlWsNr/yAlbbGeJXEFiyMwHVaz5ErO2N/CXyBfkcPCnX3tHxZ7ZORQAEPuO/vs7fJPsFbh17CcqtqJArtm3/pZzVawsx74Gojr3oz43e5DMpv48X28AefVHsrdhcZrO0koP6WHZbjYqzb7uZk6x/BwY+NQDJ33s4R5yOH92L53lj6ovM9Rpt1+pYoljmUERERH9FzZQRERkpIB28VneG/QAwr+W8yLmpZ5lu+6VgUPUcWm4vJl/4f36pvGzCVXfRTTE0wY367C3ys9hirkXzvT5yT7Lvl9tmdr36z3Tdu5USi05sKJVpO5u/awwGgCQXaRvdK97VG5L8dhtHVXsjmj7Df6zZkTYYsEuLKGxOv6uy1ueIz0X8P/myS6XpCeDs1vPl/fvfczSm1Rs6zNy3uMlN+h5ep+3nwwA+OqY3il38cBuAICy7Iq79Sri7e7zdvUBwKsfyxVtPvpILz4bcpV8v5fl55/2a5iu9z/lNhqtZ5z8VkZsdhd50Mt+rqTlcXvQAcygiIjISEYslFZnwRpb7NMOeh20Z4fKDMp3S4IuK0YAAFrO1IMq8kYXAgDWnafX+wsmt84apY5/HjnNdv7TQjksd2W+vhm855h9SHnGSnkTuuEG+wCG+stz1LFVVw50sPbsU7GF6ZcDAGrfuUfFJjb6EQDwzIF2Khb1k7yxGgyrIlTWpsdbquNIEW47L2rgeJHUv2Wp49u69QAAXN1MD/DpOVmuz9d0jr6uNE+/B6vKd+DEjQ/J15g/6QUVm7tKZlPLu+k1QcoKC8/4dU1gFVZuoFnRgwcDXJIzxwyKiIiMxAaKiIiMZEQXX3lafKHn/EBOn0BtoW+8Z/aQc3mGtuyjYksTvds92Nvdnfv03KDW+M1v5axOLSbpLQT6rLnddj5qp2eVh316oEPpEXsan4SKb56W+P6wz34+er7cKmFremcdlD2GmL1Rd8sm5wZuew/TlPaSdfHrAN9uV3s3S3i+fWBLsCs9qN9/+3rLLuN9pbpbLeG4HDASyK7guu/L9+zw7SNUbM5CubLN+i8HqNjh
SzyfOYZtOeEP3lVhsqbruWXZ53q3Q9Kfl98ck+/bNk/rRZ6drA1mUEREZCRjM6jwdXrdrAs2XA8AWN15nu26OYm+203I9rbI0mudpXvW4ksdrbfdcOv3I6tYDxLxDuP1VZ3/X23H6i1N0l+8EQDQJsf9dVwVEb/sAgDMy2+hYsNi9tiua/hjsS1Wk5QdtW+RU63W6pVR+j4rh/yvfnSKinUcJ4dnt5jgnmkAU1f1BgCMSdd7EU3pIweJLfhXVxWLi5ADQJY18Z2CYs9Pxky9EwCQsMmMOmAGRURERmIDRURERjK2i893hnfCKLnwZP839Q3NRxPlYpBpkbozaWFBPABg7NLBKpZ8r7xBWpO6nKpDya7d+oddzpXDBGVHjgAA9hb7zjmTXXwHy/SM/NprZTcM34vOazRddmF1jL9HxX68U3b3XbxbLywd96bZC0u3GbURADDufD344clGsitzQJ0VJ33sCs9bc9TUkSqWMMWs/19mUEREZCRjMyhfasOsS3Vs9GjZ6uefd0zFUsflAQCSd6yutrIRneh+DgDgoQYzbOf6bhiujhvl1dDNGw3W8im9juS96d0BAI898raKTX8zxfYYk3g3GP34ve4qVna9nM7wdOONKjZuv8yw5m/W6xEmj5efnQmZZgyIKA8zKCIiMhIbKCIiMpIruvjK0/gVmZY29omVlH8pkWMiFtR3ugh0Er5bAm3rJlfzfTX0HJ8r7Dsjm+isSbqbbv0kmXf0RUfbda2gu/3cMFiHGRQRERnJtRkUkSlqZcrNMO/be4GKPdjonwCAmJ1mbPxGleBZg88KwrX43IoZFBERGYkNFBERGYldfERnqGTP7wCAbf0TVOy2pncBAELW1ZxtR4j8jRkUEREZiRkUkZ+U7PXZ4XFvObs9EtFpYQZFRERGYgNFRERGEpZlVf5iIXIB7AhccVyrpWVZDU/3QazPCrE+/Yv16V9Vqk+AdXoS5dbpaTVQRERE1YVdfEREZCQ2UEREZCQ2UEREZCQ2UEREZCQ2UEREZCQ2UEREZCQ2UEREZCQ2UEREZCQ2UEREZKT/AD2IicCGHsOpAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 5 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_example(X_pred[:5], y_pred[error_mask][:5]);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If tensorboard was enabled, here is what the metrics could look like:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![tensorboard scalars](../assets/tensorboard_scalars.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Convolutional Network\n",
    "\n",
    "Next we want to turn it up a notch and use a convolutional neural network, which is far better\n",
    "suited for images than simple densely connected layers.\n",
    "\n",
    "PyTorch expects a 4-dimensional tensor as input for its 2D convolution layer. The dimensions represent:\n",
    "\n",
    "* Batch size\n",
    "* Number of channels\n",
    "* Height\n",
    "* Width\n",
    "\n",
    "MNIST data only has one channel since there is no color information. As stated above, each MNIST sample is a 28x28 pixel image. Hence, the resulting shape for the input tensor needs to be `(x, 1, 28, 28)`, where `x` is the batch size, which is added automatically by the data loader.\n",
    "\n",
    "Luckily, each sample in our data already has the remaining `(1, 28, 28)` shape:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([1, 28, 28])"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_example[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let us define the convolutional neural network module using PyTorch:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Cnn(nn.Module):\n",
    "    def __init__(self, dropout=0.5):\n",
    "        super(Cnn, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)\n",
    "        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)\n",
    "        self.conv2_drop = nn.Dropout2d(p=dropout)\n",
    "        self.fc1 = nn.Linear(1600, 100) # 1600 = number channels * width * height\n",
    "        self.fc2 = nn.Linear(100, 10)\n",
    "        self.fc1_drop = nn.Dropout(p=dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = torch.relu(F.max_pool2d(self.conv1(x), 2))\n",
    "        x = torch.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n",
    "        \n",
    "        # flatten over channel, height and width = 1600\n",
    "        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))\n",
    "        \n",
    "        x = torch.relu(self.fc1_drop(self.fc1(x)))\n",
    "        x = torch.softmax(self.fc2(x), dim=-1)\n",
    "        return x"
   ]
  },
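  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `1600` input features of `fc1` follow from the layer settings and can be derived step by step. The sketch below uses two hypothetical helpers, `conv_out` and `pool_out`, which are not PyTorch functions. It assumes the defaults used by the module above: convolutions with stride 1 and no padding, and max pooling whose stride equals its kernel size, with the result rounded down.\n",
    "\n",
    "```python\n",
    "def conv_out(size, kernel_size, stride=1, padding=0):\n",
    "    # spatial output size of a convolution along one dimension\n",
    "    return (size + 2 * padding - kernel_size) // stride + 1\n",
    "\n",
    "\n",
    "def pool_out(size, kernel_size):\n",
    "    # max pooling with stride equal to the kernel size, rounded down\n",
    "    return (size - kernel_size) // kernel_size + 1\n",
    "\n",
    "\n",
    "size = 28\n",
    "size = pool_out(conv_out(size, 3), 2)  # conv1 + pool: 28 -> 26 -> 13\n",
    "size = pool_out(conv_out(size, 3), 2)  # conv2 + pool: 13 -> 11 -> 5\n",
    "flat_features = 64 * size * size       # 64 channels of 5 x 5 feature maps\n",
    "print(size, flat_features)  # 5 1600\n",
    "```\n",
    "\n",
    "If the kernel sizes or the number of channels are changed, the first argument of `fc1` has to be adapted accordingly."
   ]
  },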
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also want to extend tensorboard logging by two more features:\n",
    "\n",
    "1. Add the predictions for the misclassified images to tensorboard.\n",
    "    \n",
    "    To do this, we subclass the `TensorBoard` callback and call `self.writer.add_figure` with our produced images. When subclassing, don't forget to call `super()` or the other logged metrics won't show.\n",
    "\n",
    "\n",
    "2. Add a graph of the module\n",
    "    \n",
    "    To do this, we use the summary writer's ability to add a traced graph of our module to tensorboard by calling `add_graph`. We also make sure to only call this on the very first batch by inspecting the `self.first_batch_` attribute on `TensorBoard`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "callbacks = []\n",
    "if USE_TENSORBOARD:\n",
    "    from torch.utils.tensorboard import SummaryWriter\n",
    "    from skorch.callbacks import TensorBoard\n",
    "    writer = SummaryWriter()\n",
    "\n",
    "    class MyTensorBoard(TensorBoard):\n",
    "        def __init__(self, *args, X, **kwargs):\n",
    "            self.X = X\n",
    "            super().__init__(*args, **kwargs)\n",
    "\n",
    "        def add_graph(self, module, X):\n",
    "            \"\"\"Add a graph to tensorboard.\n",
    "\n",
    "            This requires running the module with a sample from the\n",
    "            dataset.\n",
    "\n",
    "            \"\"\"\n",
    "            self.writer.add_graph(module, X.to(DEVICE))\n",
    "\n",
    "        def on_batch_begin(self, net, X, **kwargs):\n",
    "            if self.first_batch_:\n",
    "                # only add graph on very first batch\n",
    "                self.add_graph(net.module_, X)\n",
    "                \n",
    "        def add_figure(self, net):\n",
    "            # show how difficult images were classified\n",
    "            epoch = net.history[-1, 'epoch']\n",
    "            y_pred = net.predict(self.X)\n",
    "            fig = plot_example(self.X, y_pred)\n",
    "            self.writer.add_figure('difficult images', fig, global_step=epoch)\n",
    "\n",
    "        def on_epoch_end(self, net, **kwargs):\n",
    "            self.add_figure(net)\n",
    "            super().on_epoch_end(net, **kwargs)  # call super last\n",
    "\n",
    "    X_difficult = torch.stack(list(mnist_test_sliceable[error_mask][:15]))\n",
    "    callbacks.append(MyTensorBoard(writer, X=X_difficult))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As before, we can wrap skorch's `NeuralNetClassifier` around our module and train it like any other sklearn model using `.fit`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(0)\n",
    "\n",
    "cnn = NeuralNetClassifier(\n",
    "    Cnn,\n",
    "    max_epochs=10,\n",
    "    lr=0.0002,\n",
    "    optimizer=torch.optim.Adam,\n",
    "    device=DEVICE,\n",
    "    iterator_train__num_workers=4,\n",
    "    iterator_valid__num_workers=4,\n",
    "    callbacks=callbacks,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  epoch    train_loss    valid_acc    valid_loss     dur\n",
      "-------  ------------  -----------  ------------  ------\n",
      "      1        \u001b[36m0.9300\u001b[0m       \u001b[32m0.9297\u001b[0m        \u001b[35m0.2459\u001b[0m  2.9154\n",
      "      2        \u001b[36m0.3148\u001b[0m       \u001b[32m0.9541\u001b[0m        \u001b[35m0.1518\u001b[0m  2.9141\n",
      "      3        \u001b[36m0.2208\u001b[0m       \u001b[32m0.9663\u001b[0m        \u001b[35m0.1160\u001b[0m  3.0988\n",
      "      4        \u001b[36m0.1779\u001b[0m       \u001b[32m0.9701\u001b[0m        \u001b[35m0.0990\u001b[0m  2.9270\n",
      "      5        \u001b[36m0.1549\u001b[0m       \u001b[32m0.9743\u001b[0m        \u001b[35m0.0890\u001b[0m  3.0307\n",
      "      6        \u001b[36m0.1406\u001b[0m       \u001b[32m0.9759\u001b[0m        \u001b[35m0.0800\u001b[0m  2.9676\n",
      "      7        \u001b[36m0.1282\u001b[0m       \u001b[32m0.9780\u001b[0m        \u001b[35m0.0734\u001b[0m  2.9617\n",
      "      8        \u001b[36m0.1143\u001b[0m       \u001b[32m0.9795\u001b[0m        \u001b[35m0.0691\u001b[0m  2.9718\n",
      "      9        \u001b[36m0.1071\u001b[0m       \u001b[32m0.9807\u001b[0m        \u001b[35m0.0640\u001b[0m  3.0400\n",
      "     10        \u001b[36m0.1043\u001b[0m       \u001b[32m0.9816\u001b[0m        \u001b[35m0.0610\u001b[0m  2.9902\n"
     ]
    }
   ],
   "source": [
    "cnn.fit(mnist_train, y=y_train);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_pred_cnn = cnn.predict(mnist_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9856"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test, y_pred_cnn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "An accuracy of >98% should suffice for this example!\n",
    "\n",
    "Let's see how we fare on the examples that went wrong before:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.7261904761904762"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accuracy_score(y_test[error_mask], y_pred_cnn[error_mask])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great success! About 73% of the previously misclassified images are now correctly identified."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On tensorboard, in the \"IMAGES\" section, we can see how well the CNN classified the difficult images, and how that changed over the epochs:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../assets/tensorboard_digits.png\" alt=\"tensorboard digits\" style=\"width: 500px;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the \"GRAPHS\" section, we can see the graph of our module."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"../assets/tensorboard_graph.png\" alt=\"tensorboard module graph\" style=\"width: 500px;\"/>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grid searching parameter configurations\n",
    "\n",
    "Finally, we want to show an example of how to use sklearn's grid search with torch `Dataset` instances.\n",
    "\n",
    "When running a k-fold grid search, we face the same problem as before: sklearn can only do (stratified) splits when the data is sliceable. While skorch knows how to deal with PyTorch `Dataset` objects and only needs `y` to be known beforehand, sklearn doesn't know how to deal with `Dataset`s and needs a wrapper that makes them sliceable.\n",
    "\n",
    "Fortunately, we already know that skorch provides such a helper: `SliceDataset`.\n",
    "\n",
    "All that is left is to define our parameter search space and run the grid search with a sliceable instance of `mnist_train`:"
   ]
  },
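  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the notion of a sliceable wrapper more concrete, here is a minimal sketch in plain Python. The names `ToyDataset` and `ToySliceable` are hypothetical and made up for this illustration. This is not how skorch implements `SliceDataset`, which covers more cases (for example boolean masks such as `error_mask`); the sketch only shows the basic idea of answering slice and index-list lookups on top of a dataset that supports integer indexing alone.\n",
    "\n",
    "```python\n",
    "class ToyDataset:\n",
    "    '''Map-style dataset that supports len() and integer indexing only.'''\n",
    "\n",
    "    def __init__(self, n):\n",
    "        self.n = n\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.n\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        if not isinstance(i, int):\n",
    "            raise TypeError('only integer indices are supported')\n",
    "        # return an (X, y) pair; the values are arbitrary toy data\n",
    "        return i * 10, i % 2\n",
    "\n",
    "\n",
    "class ToySliceable:\n",
    "    '''Hypothetical wrapper that adds slice and index-list access.'''\n",
    "\n",
    "    def __init__(self, dataset, idx=0, indices=None):\n",
    "        self.dataset = dataset\n",
    "        self.idx = idx  # 0 selects X, 1 selects y\n",
    "        if indices is None:\n",
    "            indices = range(len(dataset))\n",
    "        self.indices = list(indices)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.indices)\n",
    "\n",
    "    def __getitem__(self, key):\n",
    "        if isinstance(key, int):\n",
    "            # single sample: delegate to the wrapped dataset\n",
    "            return self.dataset[self.indices[key]][self.idx]\n",
    "        if isinstance(key, slice):\n",
    "            chosen = self.indices[key]\n",
    "        else:\n",
    "            # a sequence of integer positions, as a CV splitter would pass\n",
    "            chosen = [self.indices[k] for k in key]\n",
    "        # return a new lazy view instead of loading the samples\n",
    "        return ToySliceable(self.dataset, self.idx, chosen)\n",
    "\n",
    "\n",
    "X = ToySliceable(ToyDataset(6))\n",
    "train_fold = X[[0, 2, 4]]  # index list, like one fold of a k-fold split\n",
    "print(len(train_fold), [train_fold[i] for i in range(len(train_fold))])\n",
    "print([X[1:3][i] for i in range(2)])\n",
    "```\n",
    "\n",
    "Because each lookup returns a new view that only stores positions, no samples are loaded until an integer index is requested."
   ]
  },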
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=Cnn(\n",
       "    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (conv2_drop): Dropout2d(p=0.5)\n",
       "    (fc1): Linear(in_features=1600, out_features=100, bias=True)\n",
       "    (fc2): Linear(in_features=100, out_features=10, bias=True)\n",
       "    (fc1_drop): Dropout(p=0.5)\n",
       "  ),\n",
       ")"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cnn.set_params(max_epochs=2, verbose=False, train_split=False, callbacks=[])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "params = {\n",
    "    'module__dropout': [0, 0.5, 0.8],\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter we are interested in here is the dropout rate. We want to see which of the values (no dropout, 50%, 80%) is the best choice for our network.\n",
    "\n",
    "Additionally, we:\n",
    "\n",
    "- Use only two epochs (`max_epochs=2`) for each `.fit`. This is only to reduce execution time; normally we would leave this value unchanged and possibly add an `EarlyStopping` callback.\n",
    "- Disable the network's print output (`verbose=False`).\n",
    "- Disable the internal train/validation split (`train_split=False`), since the grid search does k-fold validation anyway.\n",
    "- Turn off tensorboard logging (`callbacks=[]`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "cnn.initialize();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "gs = GridSearchCV(cnn, param_grid=params, scoring='accuracy', verbose=1, cv=3)"
   ]
  },
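  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get a feeling for the cost of this search before running it, the following sketch enumerates the candidates by hand using only the standard library. It is an illustration of what an exhaustive grid search has to cover, not sklearn's own code; the variable names `candidates` and `n_fits` are made up for this example.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "# same search space and number of folds as used in this notebook\n",
    "params = {'module__dropout': [0, 0.5, 0.8]}\n",
    "cv = 3\n",
    "\n",
    "# one candidate per combination of parameter values\n",
    "names = sorted(params)\n",
    "candidates = [\n",
    "    dict(zip(names, values))\n",
    "    for values in product(*(params[name] for name in names))\n",
    "]\n",
    "\n",
    "# every candidate is trained and scored once per fold\n",
    "n_fits = len(candidates) * cv\n",
    "\n",
    "for candidate in candidates:\n",
    "    print(candidate)\n",
    "print(n_fits, 'fits in total')\n",
    "```\n",
    "\n",
    "With `refit=True`, sklearn additionally trains the best candidate once more on the full training data after the search has finished."
   ]
  },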
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_train_sliceable = SliceDataset(mnist_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fitting 3 folds for each of 3 candidates, totalling 9 fits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
      "[Parallel(n_jobs=1)]: Done   9 out of   9 | elapsed:  1.1min finished\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "GridSearchCV(cv=3, error_score='raise-deprecating',\n",
       "       estimator=<class 'skorch.classifier.NeuralNetClassifier'>[initialized](\n",
       "  module_=Cnn(\n",
       "    (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))\n",
       "    (conv2_drop): Dropout2d(p=0.5)\n",
       "    (fc1): Linear(in_features=1600, out_features=100, bias=True)\n",
       "    (fc2): Linear(in_features=100, out_features=10, bias=True)\n",
       "    (fc1_drop): Dropout(p=0.5)\n",
       "  ),\n",
       "),\n",
       "       fit_params=None, iid='warn', n_jobs=None,\n",
       "       param_grid={'module__dropout': [0, 0.5, 0.8]},\n",
       "       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',\n",
       "       scoring='accuracy', verbose=1)"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gs.fit(mnist_train_sliceable, y_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After running the grid search, we know the best configuration in our search space:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'module__dropout': 0}"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gs.best_params_"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
